import json
import os
import logging
from datetime import datetime
from jinja2 import Environment
from ..utils.read_source_code import ReadSourceCode
from ..prompts.report_prompt import rep_prompt
from ..client.MakeConfig import Configuration
from ..client.MCPClient import MCPClient
from ..llm.LLM import LLMClient

class Reporter:
    """Build summary and optional LLM-assisted detailed reports from MCP
    tool validation results.

    The validation results file is expected to live inside a directory
    named ``<server_name>_2025...``; the server name is recovered from
    that directory name when :meth:`run` executes.
    """

    def __init__(self, testpath, config_path=None, detailed=False, api_key=None):
        """
        Args:
            testpath: Path to the JSON file holding validation results.
            config_path: Path to the MCP config file (required only for
                the detailed report).
            detailed: When True, additionally generate a per-tool
                failure-analysis markdown report via the LLM.
            api_key: API key forwarded to the LLM client.
        """
        self.testpath = testpath
        self.config_path = config_path
        self.detailed = detailed
        self.api_key = api_key

    async def run(self):
        """Load results, produce the summary report, and optionally the
        detailed per-tool analysis."""
        with open(self.testpath, 'r', encoding='utf-8') as file:
            test_cases = json.load(file)
        # Recover the server name from the parent directory name, e.g.
        # ".../myserver_2025xxxx/results.json" -> "myserver". os.path keeps
        # this portable; the previous '/'-split broke on Windows paths.
        parent_dir = os.path.basename(os.path.dirname(self.testpath))
        self.server_name = parent_dir.split('_2025')[0]
        # NOTE: attribute keeps its original (misspelled) name for backward
        # compatibility with any existing readers of this object.
        self.foler_name = os.path.dirname(self.testpath)
        report = self.generate_report(test_cases)
        self.save_to_file(report)
        self.print_report(report)

        if self.detailed:
            config_loader = Configuration()
            self.config = config_loader.load_config(self.config_path)
            self.llm = LLMClient(self.api_key)
            # Generate the detailed per-tool analysis report.
            await self.generate_report_detail(test_cases, report)

    @staticmethod
    def _pct(numerator, denominator):
        """Return numerator/denominator as a percentage; 0 when denominator is 0."""
        return numerator / denominator * 100 if denominator > 0 else 0

    def generate_report(self, res):
        """
        Build a structured validation report.

        Args:
            res: List of validation results; each element describes one
                test case (keys: id, toolName, input, description, query,
                expect, validation_tool, validation_eval).

        Returns:
            A report dict with a per-tool breakdown under ``"tools"`` and
            aggregate statistics under ``"summary"``.
        """
        report = {
            "tools": {},
            "summary": {
                "total_cases": 0,
                "success_cases": 0,
                "failure_cases": 0,
                "tool_validation_pass_rate": 0.0,
                "eval_validation_pass_rate": 0.0,
                "tools_summary": {}
            }
        }

        for case in res:
            tool_name = case["toolName"]
            case_id = case["id"]

            tool_entry = report["tools"].setdefault(tool_name, {
                "cases": {},
                "summary": {
                    "total_cases": 0,
                    "success_cases": 0,
                    "tool_validation_pass": 0,
                    "eval_validation_pass": 0
                }
            })

            rule_results = case["validation_tool"]["rule_results"]
            if isinstance(rule_results, list):
                total_rules = len(rule_results)
                passed_rules = sum(1 for rule in rule_results if rule.get("rule_passed", False))
            else:
                # Non-list rule_results (e.g. an error string) counts as no rules.
                total_rules = 0
                passed_rules = 0

            # Tool validation passes when at least half of the rules pass.
            tool_passed = total_rules > 0 and passed_rules / total_rules >= 0.5

            eval_passed = case["validation_eval"]["passed"]

            case_result = {
                "id": case_id,
                "input": case["input"],
                "description": case["description"],
                "query": case["query"],
                "expect": case["expect"],
                "validation_tool": {
                    "passed": tool_passed,
                    "rule_pass_status": f"{passed_rules}/{total_rules}",
                    "total_rules": total_rules,
                    "passed_rules": passed_rules,
                    "pass_rate": passed_rules / total_rules if total_rules > 0 else 0.0,
                    "rule_results": rule_results
                },
                "validation_eval": {
                    "passed": eval_passed,
                    "message": case["validation_eval"]["message"]
                },
                # A case only counts as an overall pass when both validations
                # pass AND it was expected to succeed.
                "overall_pass": tool_passed and eval_passed and case["expect"] == "success"
            }

            tool_entry["cases"][case_id] = case_result

            summary = tool_entry["summary"]
            summary["total_cases"] += 1
            if case["expect"] == "success":
                summary["success_cases"] += 1
            if tool_passed:
                summary["tool_validation_pass"] += 1
            if eval_passed:
                summary["eval_validation_pass"] += 1

        # Aggregate statistics across all tools.
        total_cases = 0
        success_cases = 0
        passed_tool_validation = 0
        passed_eval_validation = 0

        for tool_name, tool_data in report["tools"].items():
            ts = tool_data["summary"]
            total_cases += ts["total_cases"]
            success_cases += ts["success_cases"]
            passed_tool_validation += ts["tool_validation_pass"]
            passed_eval_validation += ts["eval_validation_pass"]

            # Per-tool pass rates (percentages).
            report["summary"]["tools_summary"][tool_name] = {
                "total_cases": ts["total_cases"],
                "success_rate": self._pct(ts["success_cases"], ts["total_cases"]),
                "tool_validation_pass_rate": self._pct(ts["tool_validation_pass"], ts["total_cases"]),
                "eval_validation_pass_rate": self._pct(ts["eval_validation_pass"], ts["total_cases"])
            }

        report["summary"]["total_cases"] = total_cases
        report["summary"]["success_cases"] = success_cases
        report["summary"]["failure_cases"] = total_cases - success_cases

        # Every case contributes exactly one tool validation and one eval
        # validation, so the denominator of both overall rates is total_cases.
        if total_cases > 0:
            report["summary"]["tool_validation_pass_rate"] = self._pct(passed_tool_validation, total_cases)
            report["summary"]["eval_validation_pass_rate"] = self._pct(passed_eval_validation, total_cases)
        return report

    async def generate_report_detail(self, res, report):
        """
        Generate a detailed markdown analysis for every tool with failures.

        For each tool that did not pass all validations, failing cases are
        collected, rendered into an LLM prompt together with the tool's
        source code and input schema, and the LLM's analysis is appended to
        a per-server markdown summary file.

        Args:
            res: Raw validation results list (same data passed to
                :meth:`generate_report`).
            report: The structured report produced by :meth:`generate_report`.
        """
        srv_config = self.config["mcpServers"].get(self.server_name, "")
        if not srv_config:
            # Without a valid server entry we cannot start the MCP client,
            # so bail out instead of initializing with an empty config
            # (the original logged the error but carried on regardless).
            logging.error(f"生成详细的测试报告需要输入正确的MCP Config文件")
            return

        server = MCPClient(self.server_name, srv_config)
        await server.initialize()
        try:
            tools = await server.list_tools()
            if not tools:
                # The original built a Warning() instance and discarded it;
                # actually emit the warning through logging.
                logging.warning('No tools found in the MCP server.')

            tools_info = report["tools"]
            readsc = ReadSourceCode(self.config_path)
            tool_functions = readsc.get_code(self.server_name)
            jinja_env = Environment()
            summary_file_path = os.path.join(
                self.foler_name,
                f"{self.server_name}_all_tools_analysis.md"
            )

            if not os.path.exists(summary_file_path):
                try:
                    with open(summary_file_path, 'w', encoding='utf-8') as f:
                        f.write(f"# MCP服务器{self.server_name}工具分析汇总报告\n")
                        f.write(f"- 生成时间：{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                        f.write(f"- 工具数量：{len(tools_info)}\n")
                    logging.info(f"汇总分析报告已创建：{summary_file_path}")
                except Exception as e:
                    logging.error(f"创建汇总分析报告失败：{str(e)}", exc_info=True)
                    return

            for tool_name, tool_data in tools_info.items():
                ts = tool_data["summary"]
                # Skip tools where every case passed both validations.
                if ts["tool_validation_pass"] == ts["total_cases"] and ts["eval_validation_pass"] == ts["total_cases"]:
                    continue

                tool = self.get_tool(tool_name, tools)
                if tool is None:
                    # Tool appears in the results but not on the live server;
                    # the original would crash on tool.description below.
                    logging.warning(f"工具 {tool_name} 不在MCP服务器工具列表中，跳过详细分析")
                    continue

                tool_failed_details = []
                eval_failed_details = []
                for case_id, case in tool_data["cases"].items():
                    val_case = self.query_id_from_vallist(case_id, res)
                    if not val_case:
                        continue
                    if not case["validation_tool"]["passed"]:
                        failed_rules = [
                            rule for rule in val_case["validation_tool"]["rule_results"]
                            if not rule["rule_passed"]
                        ]
                        tool_failed_details.append({
                            "input": case["input"],
                            "description": case["description"],
                            "expect": case["expect"],
                            "env_script": val_case.get("env_script", ""),
                            "tool_output": val_case["validation_tool"]["output"],
                            "rule_not_passed": failed_rules
                        })
                    if not case["validation_eval"]["passed"]:
                        eval_failed_details.append({
                            "query": case["query"],
                            "eval_output": val_case["validation_eval"]["output"],
                            "message": val_case["validation_eval"]["message"]
                        })

                # Cap the examples sent to the LLM to keep the prompt small.
                tool_failed_details = tool_failed_details[:4]
                eval_failed_details = eval_failed_details[:4]

                # .get avoids a KeyError when ReadSourceCode could not locate
                # the source for this tool.
                tool_function = tool_functions.get(tool_name, "")

                # The schema may be a plain dict or an object exposing a
                # .properties attribute; the original mixed hasattr() with
                # dict indexing, which fails for either representation.
                schema = getattr(tool, "input_schema", None)
                if isinstance(schema, dict):
                    properties = schema.get("properties")
                else:
                    properties = getattr(schema, "properties", None)
                input_properties = json.dumps(properties, indent=2) if properties else "{}"

                rep_template = jinja_env.from_string(rep_prompt)
                rep_prompt_formatted = rep_template.render(
                    tool_name=tool_name,
                    server_name=self.server_name,
                    tool_failed_details=tool_failed_details,
                    eval_failed_details=eval_failed_details,
                    tool_description=tool.description,
                    input_properties=input_properties,
                    tool_function=tool_function,
                )
                rep_output = self.llm.get_response([{"role": "user", "content": rep_prompt_formatted}])

                try:
                    with open(summary_file_path, 'a', encoding='utf-8') as f:
                        f.write(f"## 工具：{tool_name}\n")
                        f.write(f"- 失败用例数（工具验证）：{len(tool_failed_details)}\n")
                        f.write(f"- 失败用例数（评估验证）：{len(eval_failed_details)}\n")
                        f.write(f"\n")
                        f.write(rep_output)
                    logging.info(f"工具 {tool_name} 的分析结果已追加到汇总报告：{summary_file_path}")
                except Exception as e:
                    logging.error(f"追加工具 {tool_name} 的分析结果失败：{str(e)}", exc_info=True)
        finally:
            # Always release the MCP client, even on early returns above
            # (the original skipped cleanup on those paths).
            await server.cleanup()

    def query_id_from_vallist(self, case_id, testcases):
        """Return the raw test case with the given id, or None if absent."""
        return next((case for case in testcases if case["id"] == case_id), None)

    def get_tool(self, tool_name, tools):
        """Return the MCP tool object with the given name, or None if absent."""
        return next((tool for tool in tools if tool.name == tool_name), None)

    def save_to_file(self, report):
        """Save the structured report as report.json next to the results file."""
        report_path = os.path.join(self.foler_name, "report.json")
        with open(report_path, 'w', encoding='utf-8') as file:
            json.dump(report, file, ensure_ascii=False, indent=4)
        print(f"报告已保存到 {report_path}")

    def print_report(self, report):
        """Print a concise report to stdout and mirror it to a markdown file."""
        # Console output and the summary file share the same content, so
        # build the lines once instead of duplicating every format string.
        lines = [
            "=" * 50,
            f"{self.server_name} 验证结果汇总报告",
            "=" * 50,
            f"总测试用例数: {report['summary']['total_cases']}",
            f"预期成功的用例数: {report['summary']['success_cases']}",
            f"预期失败的用例数: {report['summary']['failure_cases']}",
            f"工具验证总体通过率: {report['summary']['tool_validation_pass_rate']:.2f}%",
            f"评估验证总体通过率: {report['summary']['eval_validation_pass_rate']:.2f}%",
            "",
            "-" * 50,
            "各工具统计:",
            "-" * 50,
        ]
        for tool_name, tool_stats in report["summary"]["tools_summary"].items():
            lines += [
                "",
                f"工具: {tool_name}",
                f"  总用例数: {tool_stats['total_cases']}",
                f"  工具验证通过率: {tool_stats['tool_validation_pass_rate']:.2f}%",
                f"  评估验证通过率: {tool_stats['eval_validation_pass_rate']:.2f}%",
            ]

        print("\n".join(lines))

        logpath = os.path.join(self.foler_name, f"{self.server_name}_report_summary.md")
        with open(logpath, 'w', encoding='utf-8') as f:
            f.write("\n".join(lines) + "\n")
        print(f"报告已成功写入文件: {logpath}")


        